#include <gtest/gtest.h>
#include "c/ddk/graph/handle_types.h"
#include "c/ddk/graph/graph.h"
#include "c/ddk/graph/operator.h"
#include "c/ddk/graph/attr_value.h"
#include "graph/csrc/resource_manager.h"
#include "c/ddk/graph/context.h"
#include "c/ddk/graph/op_config.h"
#include "graph/operator.h"
#include "graph/shape.h"

#include "framework/graph/utils/op_desc_utils.h"
#include "framework/graph/utils/attr_utils.h"

// Sizes of the graph input/output handle arrays passed to
// HIAI_IR_SetInputs / HIAI_IR_SetOutputs in the tests below.
constexpr uint32_t INPUT_NUM{1U};
constexpr uint32_t OUTPUT_NUM{1U};

using namespace std;
using namespace hiai;

class UTEST_c_graph : public testing::Test {
public:
    void SetUp()
    {
        resMgr_ = HIAI_IR_ResourceManagerCreate();
    }

    void TearDown()
    {
        HIAI_IR_ResourceManagerDestroy(&resMgr_);
    }
public:
    GraphHandle CreateGraph()
    {
        const char* name = "graph_defalut";
        GraphHandle graph = HIAI_IR_GraphCreate(resMgr_, name);
        return graph;
    }

    OpHandle CreateData()
    {
        const char* name = "data1";
        uint32_t shapeNum = 4;
        int64_t inputShape[4] = { 10, 4, 120, 440 };
        OpHandle dataOp = HIAI_IR_CreatePlaceHodler(
            resMgr_, name, HIAI_Format::HIAI_FORMAT_NCHW, HIAI_DATATYPE_FLOAT32, inputShape, shapeNum);
        return dataOp;
    }

    OpHandle CreateConv(OpHandle dataOp)
    {
        HIAI_IR_Input input[1]; // 根据输入个数确定数组大小
        input[0].op = dataOp;
        input[0].outIndex = 0;

        int64_t stridesShape[] = {2, 2};
        AttrHandle stridesAttr = HIAI_IR_CreateInt64LstAttr(resMgr_, stridesShape, 2);
        HIAI_IR_Params params[1];
        params[0].name = "strides";
        params[0].attrValue = stridesAttr;

        HIAI_IR_OpConfig opConfig;
        opConfig.opType = "Convolution";
        opConfig.opName = "conv1";
        opConfig.input = input;
        opConfig.inputNums = 1;
        opConfig.params = params;
        opConfig.paramsNums = 1;

        OpHandle convOp = HIAI_IR_CreateOp(resMgr_, &opConfig);
        return convOp;
    }

    ResMgrHandle resMgr_ = nullptr;
};

TEST_F(UTEST_c_graph, HIAI_IR_GraphCreate)
{
    // Creating a graph through the fixture's resource manager must yield a non-null handle.
    EXPECT_NE(CreateGraph(), nullptr);
}

TEST_F(UTEST_c_graph, HIAI_IR_SetInputs_SetOutputs)
{
    // Step 1: create the graph. Later steps use this handle, so a null result
    // is a fatal failure (ASSERT) rather than a non-fatal EXPECT.
    GraphHandle graph = CreateGraph();
    ASSERT_NE(graph, nullptr);

    // Step 2: create the ops used as graph inputs/outputs.
    //    data1
    //      \    const1
    //       \    /
    //        conv1
    // NOTE(review): the diagram shows a const1 input, but CreateConv() wires only
    // data1 (strides are passed as an attribute); no const op is created here.

    // Step 2.1: create the Data (placeholder) op.
    OpHandle dataOp = CreateData();
    ASSERT_NE(dataOp, nullptr);
    // Step 2.2: create the Convolution op fed by data1.
    OpHandle convOp = CreateConv(dataOp);
    ASSERT_NE(convOp, nullptr);
    // Step 2.3: build the input and output handle arrays.
    ConstOpHandle inputs[INPUT_NUM] = { dataOp };
    ConstOpHandle outputs[OUTPUT_NUM] = { convOp };

    // Step 3: set the graph inputs and outputs.
    HIAI_Status status = HIAI_IR_SetInputs(resMgr_, graph, inputs, INPUT_NUM);
    EXPECT_EQ(status, HIAI_SUCCESS);
    status = HIAI_IR_SetOutputs(resMgr_, graph, outputs, OUTPUT_NUM);
    EXPECT_EQ(status, HIAI_SUCCESS);

    // Step 4: verify the result.
    // TODO: checking only the return codes does not prove the graph matches the
    // expected structure after the Set calls; add a structural check once a
    // query API for graph inputs/outputs is available.
}